// Copyright 2019 DeepMind Technologies Ltd. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

/// Compute the exploitability of a policy in a sequential zero-sum game.
///
/// This is the expected utility that a worst-case opponent achieves against the policy.
///
/// In order to identify the best response, we construct information sets, each
/// consisting of an array of pairs of states and their counterfactual reach
/// probabilities, i.e., the probability that the state would be reached if the best
/// responder always played to reach it. This is the product of the probabilities of
/// the necessary chance events and opponent action choices required to reach the node.
///
/// These probabilities give us the correct weighting for possible states of the
/// world when considering our best response for a particular information set.
///
/// The values we calculate are values of being in the specific state. Unlike in a
/// CFR algorithm, they are not weighted by reach probabilities. These values
/// take into account the whole state, so they may depend on information which is
/// unknown to the best-responding player.

/// A dictionary-backed memoization cache.
///
/// Declared as a class so that every holder of the cache shares (and can populate)
/// the same underlying storage.
fileprivate class MemoizingCache<Key: Hashable, Value> {
  /// Results computed so far, keyed by the input that produced them.
  var dict: [Key: Value] = [:]

  /// Returns the stored result for `key`, computing and storing it on first request.
  ///
  /// - Parameters:
  ///   - key: The input whose result is wanted.
  ///   - function: Produces the result for `key`; called only when nothing is stored yet.
  /// - Returns: The stored or freshly computed result.
  func cacheResult(for key: Key, in function: (Key) -> Value) -> Value {
    guard let stored = dict[key] else {
      // The result is written only after `function` returns, so `function` is free to
      // call back into this cache for other keys while it runs.
      let computed = function(key)
      dict[key] = computed
      return computed
    }
    return stored
  }
}

extension Sequence {
  /// The maximum value in a sequence when values are ordered by a unary sort key function.
  ///
  /// `key` is evaluated once per element (and not at all when the sequence has fewer
  /// than two elements), so an expensive key is never recomputed for the running maximum.
  /// When several elements share the largest key, the first of them is returned.
  ///
  /// - Parameter key: Maps an element to the comparable value it is ranked by.
  /// - Returns: The element with the largest key, or `nil` if the sequence is empty.
  func max<T: Comparable>(key: (Element) -> T) -> Element? {
    var iterator = makeIterator()
    guard var best = iterator.next() else { return nil }
    // A lone element wins without its key ever being needed.
    guard let second = iterator.next() else { return best }

    var bestKey = key(best)
    var pending: Element? = second
    while let candidate = pending {
      let candidateKey = key(candidate)
      // Strict comparison keeps the earliest element among equal keys.
      if bestKey < candidateKey {
        best = candidate
        bestKey = candidateKey
      }
      pending = iterator.next()
    }
    return best
  }
}

/// Policy representing the best response to a given stochastic policy in every state.
/// This is a unilateral strategy (i.e., it assumes that the opponents do not change
/// their policy).
///
/// Construction walks the game tree from the initial state and groups the best
/// responder's decision states by information state string. State values and
/// best-response actions are computed on demand and memoized.
public struct BestResponse<Game, Policy: StochasticPolicy>: DeterministicPolicy
  where Game == Policy.Game, Game.State.Game.Action == Game.Action {
  /// The player for whom the best response is computed.
  let player: Player
  /// The policy consulted at nodes that belong to neither `player` nor chance.
  let opponentPolicy: Policy

  /// Maps each information state string of `player` to the states sharing it, each
  /// paired with its counterfactual reach probability.
  var informationSets: [String: [(Game.State, Double)]] = [:]
  /// A cache of state values for the best-responding player.
  /// These may depend on information not available to the best responder at that state.
  private let valueCache: MemoizingCache<Game.State, Double> = MemoizingCache()
  /// A cache of the best-response actions for every information state.
  private let bestResponseActionCache: MemoizingCache<String, Game.Action> = MemoizingCache()

  /// Creates the best response for `player` against `policy` in `game`.
  ///
  /// - Parameters:
  ///   - game: The game whose tree is traversed, starting at `game.initialState`.
  ///   - player: The best-responding player.
  ///   - policy: The policy the other players are assumed to follow.
  public init(game: Game, player: Player, policy: Policy) {
    self.player = player
    self.opponentPolicy = policy

    // Compute infosets by recursing over all states in the game tree.
    forEachChildState(from: game.initialState) { state, counterfactualReachProbability in
      informationSets[state.informationStateString(for: player), default: []].append(
        (state, counterfactualReachProbability))
    }
  }

  /// Recurse over all non-terminal states accessible from `state` and apply `body` to each
  /// state at which `player` is to act, together with its counterfactual reach
  /// probability measured from `state`.
  func forEachChildState(from state: Game.State, _ body: (Game.State, Double) -> ()) {
    if state.isTerminal { return }
    if state.currentPlayer == player {
      body(state, 1.0)
    }
    for (action, actionProbability) in counterfactualActionProbabilities(for: state) {
      forEachChildState(from: state.applying(action)) { childState, stateProbability in
        // Scale the probability measured from the child by the step that led to the child.
        body(childState, actionProbability * stateProbability)
      }
    }
  }

  /// The probability assigned to each action at `state` when computing counterfactual
  /// reach: 1 for every legal action of the best responder, the chance outcome
  /// distribution at chance nodes, and `opponentPolicy`'s distribution otherwise.
  func counterfactualActionProbabilities(for state: Game.State) -> [Game.Action: Double] {
    switch state.currentPlayer {
    case player:
      // Counterfactual probabilities exclude the best responder's actions.
      return Dictionary(uniqueKeysWithValues: state.legalActions.map { action in (action, 1.0) })
    case .chance:
      return state.chanceOutcomes
    default:
      return opponentPolicy.actionProbabilities(forState: state)
    }
  }

  /// The expected utility for the best responder starting at a given state.
  ///
  /// Terminal states yield the best responder's utility; at the best responder's own
  /// nodes the value is that of the state reached by the best-response action; at all
  /// other nodes it is the probability-weighted sum of the successor values.
  /// Results are memoized per state.
  func value(_ state: Game.State) -> Double {
    return valueCache.cacheResult(for: state) { state in
      switch state.currentPlayer {
      case .terminal:
        return state.utility(for: player)
      case player:
        return value(state.applying(action(forState: state)))
      default:
        return counterfactualActionProbabilities(for: state).map { action, probability in
          probability * value(state.applying(action))
        }.reduce(0, +)
      }
    }
  }

  /// The best response for a given information state.
  ///
  /// Halts with a fatal error if `informationState` was not recorded while traversing
  /// the game tree, or if its states offer no legal actions. Results are memoized per
  /// information state.
  func bestResponseAction(_ informationState: String) -> Game.Action {
    return bestResponseActionCache.cacheResult(for: informationState) { informationState in
      guard let informationSet = informationSets[informationState] else {
        fatalError("No information set recorded for information state: \(informationState)")
      }
      // Get actions (arbitrarily) from the first state in the infoset.
      guard let representative = informationSet.first else {
        fatalError("Empty information set for information state: \(informationState)")
      }
      let legalActions = representative.0.legalActions
      // Pick the best action by counterfactual-reach-probability-weighted state-value.
      let bestAction = legalActions.max(key: { action in
        informationSet.map { state, counterfactualReachProbability in
          counterfactualReachProbability * value(state.applying(action))
        }.reduce(0, +)
      })
      guard let chosenAction = bestAction else {
        fatalError("No legal actions available in information state: \(informationState)")
      }
      return chosenAction
    }
  }

  /// The best response for every information state.
  public var bestResponseActions: [String: Game.Action] {
    return Dictionary(uniqueKeysWithValues: informationSets.keys.map { informationState in
      (informationState, bestResponseAction(informationState))
    })
  }

  /// The best-response action at `state`, looked up through the information state
  /// string that `player` observes there.
  ///
  /// - Precondition: `player` is the current player at `state`.
  public func action(forState state: Game.State) -> Game.Action {
    precondition(
      state.currentPlayer == player,
      "BestResponse can only act at states where the best responder is the current player.")
    return bestResponseAction(state.informationStateString(for: player))
  }
}

extension GameProtocol {
  /// Measures how far a policy is from a Nash equilibrium of this game.
  ///
  /// For each player, the gain from unilaterally switching to the best response against
  /// `policy` is evaluated at the initial state; the result is the sum of those gains.
  /// For zero-sum games this equals the exploitability of the policy times the number
  /// of players.
  /// See https://arxiv.org/abs/1711.00832 for more details.
  ///
  /// - Parameter policy: The policy to evaluate.
  /// - Returns: The summed per-player improvement of the best response over `policy`.
  /// - Note: Halts with a fatal error when `utilitySum` is `nil`.
  public func nashConv<Policy: StochasticPolicy>(_ policy: Policy) -> Double
    where Policy.Game == Self {
    if utilitySum == nil {
      fatalError("nashConv is undefined on non-constant-sum games.")
    }

    let playerIDs = 0..<playerCount

    // What each player obtains by best-responding while everyone else follows `policy`.
    var bestResponseValues: [Double] = []
    for playerID in playerIDs {
      let bestResponse = BestResponse(game: self, player: .player(playerID), policy: policy)
      bestResponseValues.append(bestResponse.value(initialState))
    }

    // What each player obtains when everyone, including that player, follows `policy`.
    let onPolicyValues = playerIDs.map { playerID in
      value(for: .player(playerID), in: initialState, under: policy)
    }

    var totalImprovement = 0.0
    for (bestResponseValue, onPolicyValue) in zip(bestResponseValues, onPolicyValues) {
      totalImprovement += bestResponseValue - onPolicyValue
    }
    return totalImprovement
  }
}
